{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "febd4dc3",
   "metadata": {},
   "source": [
    "# Quick Start"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad612c7a",
   "metadata": {},
   "source": [
    "In this part, we provide some basic yet all-in-one examples to show the features of Torch-Pruning. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "72b0a38f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "import sys, os\n",
    "sys.path.append(os.path.abspath(\"../\"))\n",
    "\n",
    "import torch\n",
    "from torchvision.models import resnet18\n",
    "import torch_pruning as tp"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa6df445",
   "metadata": {},
   "source": [
    "### Method 1. Pruning with DepGraph\n",
    "\n",
    "``DependencyGraph`` serves as the cornerstone of Torch-Pruning, which automatically identifies and groups all layers with inter-dependency. In structural pruning, two layers with dependency should be pruned simultaneously. Therefore, to prune a complicated model, we need to handle those layers carefully. The following example shows the pipeline of pruning single layer in a ResNet-18."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ef2915f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 0. prepare your model and example inputs\n",
    "model = resnet18(pretrained=True).eval()\n",
    "example_inputs = torch.randn(1,3,224,224)\n",
    "\n",
    "# 1. build dependency graph for resnet18\n",
    "DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)\n",
    "\n",
    "# 2. Select some channels to prune. Here we prune the channels indexed by [2, 6, 9].\n",
    "pruning_idxs = pruning_idxs=[2, 6, 9]\n",
    "pruning_group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=pruning_idxs )\n",
    "\n",
    "# 3. prune all grouped layer that is coupled with model.conv1\n",
    "if DG.check_pruning_group(pruning_group):\n",
    "    pruning_group.prune()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0517ce0",
   "metadata": {},
   "source": [
    "After invoking the ``.exec`` method, an inplace pruning will be applied to the model. Upon printing the model, we can notice that multiple layers, such as \"model.conv1\", \"model.bn1\", and \"model.layer1[0].conv1\" are pruned by Torch-Pruning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7106239a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After pruning:\n",
      "ResNet(\n",
      "  (conv1): Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
      "  (bn1): BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "  (relu): ReLU(inplace=True)\n",
      "  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
      "  (layer1): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer2): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(61, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(61, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer3): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer4): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
      "  (fc): Linear(in_features=512, out_features=1000, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "print(\"After pruning:\")\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1de3ee9a",
   "metadata": {},
   "source": [
    "Let's inspect the pruning group. The results will show how a pruning operation triggers (=>) another one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fe52ed58",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "--------------------------------\n",
      "          Pruning Group\n",
      "--------------------------------\n",
      "[0] [DEP] prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), #Pruned=3\n",
      "[1] [DEP] prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #Pruned=3\n",
      "[2] [DEP] prune_out_channels on bn1 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp(ReluBackward0), #Pruned=3\n",
      "[3] [DEP] prune_out_channels on _ElementWiseOp(ReluBackward0) => prune_out_channels on _ElementWiseOp(MaxPool2DWithIndicesBackward0), #Pruned=3\n",
      "[4] [DEP] prune_out_channels on _ElementWiseOp(MaxPool2DWithIndicesBackward0) => prune_out_channels on _ElementWiseOp(AddBackward0), #Pruned=3\n",
      "[5] [DEP] prune_out_channels on _ElementWiseOp(MaxPool2DWithIndicesBackward0) => prune_in_channels on layer1.0.conv1 (Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #Pruned=3\n",
      "[6] [DEP] prune_out_channels on _ElementWiseOp(AddBackward0) => prune_out_channels on layer1.0.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #Pruned=3\n",
      "[7] [DEP] prune_out_channels on _ElementWiseOp(AddBackward0) => prune_out_channels on _ElementWiseOp(ReluBackward0), #Pruned=3\n",
      "[8] [DEP] prune_out_channels on _ElementWiseOp(ReluBackward0) => prune_out_channels on _ElementWiseOp(AddBackward0), #Pruned=3\n",
      "[9] [DEP] prune_out_channels on _ElementWiseOp(ReluBackward0) => prune_in_channels on layer1.1.conv1 (Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #Pruned=3\n",
      "[10] [DEP] prune_out_channels on _ElementWiseOp(AddBackward0) => prune_out_channels on layer1.1.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #Pruned=3\n",
      "[11] [DEP] prune_out_channels on _ElementWiseOp(AddBackward0) => prune_out_channels on _ElementWiseOp(ReluBackward0), #Pruned=3\n",
      "[12] [DEP] prune_out_channels on _ElementWiseOp(ReluBackward0) => prune_in_channels on layer2.0.downsample.0 (Conv2d(61, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)), #Pruned=3\n",
      "[13] [DEP] prune_out_channels on _ElementWiseOp(ReluBackward0) => prune_in_channels on layer2.0.conv1 (Conv2d(61, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), #Pruned=3\n",
      "[14] [DEP] prune_out_channels on layer1.1.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.1.conv2 (Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #Pruned=3\n",
      "[15] [DEP] prune_out_channels on layer1.0.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.0.conv2 (Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #Pruned=3\n",
      "--------------------------------\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(pruning_group)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b573e540",
   "metadata": {},
   "source": [
    "You can also get all groups from ``DependencyGraph`` using the ``get_all_groups`` method. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "99b564ad",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of Groups: 13\n",
      "The last Group: \n",
      "--------------------------------\n",
      "          Pruning Group\n",
      "--------------------------------\n",
      "[0] [DEP] prune_out_channels on layer4.1.bn1 (BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer4.1.bn1 (BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #Pruned=512\n",
      "[1] [DEP] prune_out_channels on layer4.1.bn1 (BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer4.1.conv1 (Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #Pruned=512\n",
      "[2] [DEP] prune_out_channels on layer4.1.bn1 (BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp(ReluBackward0), #Pruned=512\n",
      "[3] [DEP] prune_out_channels on _ElementWiseOp(ReluBackward0) => prune_in_channels on layer4.1.conv2 (Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), #Pruned=512\n",
      "--------------------------------\n",
      "\n"
     ]
    }
   ],
   "source": [
    "all_groups = list(DG.get_all_groups())\n",
    "print(\"Number of Groups: %d\"%len(all_groups))\n",
    "print(\"The last Group:\", all_groups[-1])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "faf9a76f",
   "metadata": {},
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb13e0b8",
   "metadata": {},
   "source": [
    "### 2. High-level Pruners\n",
    "\n",
    "Pruning a neural network using the ``DependencyGraph`` can be still complicated, especially for models with numerous layers. Therefore, we also offer high-level pruners to simplify this process. For example, you can easily prune a ResNet18 model with a simple magnitude-based pruner. This method removes weights with small magnitude in the network, resulting in a smaller and faster model without too much performance lost in accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a05fd44b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ResNet(\n",
      "  (conv1): Conv2d(3, 57, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
      "  (bn1): BatchNorm2d(57, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "  (relu): ReLU(inplace=True)\n",
      "  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
      "  (layer1): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(57, 57, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(57, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(57, 57, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(57, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(57, 57, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(57, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(57, 57, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(57, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer2): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(57, 115, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(115, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(115, 115, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(115, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(57, 115, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(115, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(115, 115, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(115, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(115, 115, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(115, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer3): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(115, 230, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(230, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(230, 230, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(230, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(115, 230, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(230, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(230, 230, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(230, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(230, 230, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(230, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer4): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(230, 460, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(460, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(460, 460, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(460, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(230, 460, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(460, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(460, 460, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(460, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(460, 460, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(460, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
      "  (fc): Linear(in_features=460, out_features=1000, bias=True)\n",
      ")\n",
      "torch.Size([1, 1000])\n",
      "  Iter 1/5, Params: 11.69 M => 9.48 M\n",
      "  Iter 1/5, MACs: 1.82 G => 1.47 G\n",
      "ResNet(\n",
      "  (conv1): Conv2d(3, 51, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
      "  (bn1): BatchNorm2d(51, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "  (relu): ReLU(inplace=True)\n",
      "  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
      "  (layer1): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(51, 51, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(51, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(51, 51, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(51, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(51, 51, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(51, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(51, 51, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(51, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer2): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(51, 102, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(102, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(102, 102, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(102, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(51, 102, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(102, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(102, 102, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(102, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(102, 102, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(102, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer3): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(102, 204, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(204, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(204, 204, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(204, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(102, 204, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(204, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(204, 204, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(204, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(204, 204, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(204, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer4): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(204, 409, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(409, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(409, 409, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(409, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(204, 409, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(409, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(409, 409, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(409, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(409, 409, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(409, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
      "  (fc): Linear(in_features=409, out_features=1000, bias=True)\n",
      ")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 1000])\n",
      "  Iter 2/5, Params: 11.69 M => 7.53 M\n",
      "  Iter 2/5, MACs: 1.82 G => 1.18 G\n",
      "ResNet(\n",
      "  (conv1): Conv2d(3, 44, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
      "  (bn1): BatchNorm2d(44, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "  (relu): ReLU(inplace=True)\n",
      "  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
      "  (layer1): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(44, 44, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(44, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(44, 44, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(44, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(44, 44, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(44, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(44, 44, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(44, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer2): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(44, 89, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(89, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(89, 89, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(89, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(44, 89, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(89, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(89, 89, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(89, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(89, 89, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(89, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer3): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(89, 179, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(179, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(179, 179, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(179, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(89, 179, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(179, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(179, 179, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(179, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(179, 179, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(179, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer4): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(179, 358, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(358, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(358, 358, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(358, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(179, 358, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(358, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(358, 358, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(358, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(358, 358, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(358, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
      "  (fc): Linear(in_features=358, out_features=1000, bias=True)\n",
      ")\n",
      "torch.Size([1, 1000])\n",
      "  Iter 3/5, Params: 11.69 M => 5.82 M\n",
      "  Iter 3/5, MACs: 1.82 G => 0.91 G\n",
      "ResNet(\n",
      "  (conv1): Conv2d(3, 38, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
      "  (bn1): BatchNorm2d(38, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "  (relu): ReLU(inplace=True)\n",
      "  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
      "  (layer1): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(38, 38, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(38, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(38, 38, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(38, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(38, 38, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(38, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(38, 38, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(38, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer2): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(38, 76, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(76, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(76, 76, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(76, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(38, 76, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(76, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(76, 76, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(76, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(76, 76, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(76, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer3): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(76, 153, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(153, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(153, 153, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(153, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(76, 153, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(153, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(153, 153, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(153, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(153, 153, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(153, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer4): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(153, 307, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(307, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(307, 307, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(307, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(153, 307, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(307, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(307, 307, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(307, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(307, 307, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(307, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
      "  (fc): Linear(in_features=307, out_features=1000, bias=True)\n",
      ")\n",
      "torch.Size([1, 1000])\n",
      "  Iter 4/5, Params: 11.69 M => 4.32 M\n",
      "  Iter 4/5, MACs: 1.82 G => 0.68 G\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ResNet(\n",
      "  (conv1): Conv2d(3, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
      "  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "  (relu): ReLU(inplace=True)\n",
      "  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
      "  (layer1): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer2): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer3): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (layer4): Sequential(\n",
      "    (0): BasicBlock(\n",
      "      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (downsample): Sequential(\n",
      "        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (1): BasicBlock(\n",
      "      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      (relu): ReLU(inplace=True)\n",
      "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    )\n",
      "  )\n",
      "  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
      "  (fc): Linear(in_features=256, out_features=1000, bias=True)\n",
      ")\n",
      "torch.Size([1, 1000])\n",
      "  Iter 5/5, Params: 11.69 M => 3.06 M\n",
      "  Iter 5/5, MACs: 1.82 G => 0.49 G\n"
     ]
    }
   ],
   "source": [
    "model = resnet18(pretrained=True)\n",
    "example_inputs = torch.randn(1, 3, 224, 224)\n",
    "\n",
    "# 0. importance criterion for parameter selections\n",
    "imp = tp.importance.MagnitudeImportance(p=2, group_reduction='mean')\n",
    "\n",
    "# 1. ignore some layers that should not be pruned, e.g., the final classifier layer.\n",
    "ignored_layers = []\n",
    "for m in model.modules():\n",
    "    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:\n",
    "        ignored_layers.append(m) # DO NOT prune the final classifier!\n",
    "\n",
    "        \n",
    "# 2. Pruner initialization\n",
    "iterative_steps = 5 # You can prune your model to the target pruning ratio iteratively.\n",
    "pruner = tp.pruner.MagnitudePruner(\n",
    "    model, \n",
    "    example_inputs, \n",
    "    global_pruning=False, # If False, a uniform ratio will be assigned to different layers.\n",
    "    importance=imp, # importance criterion for parameter selection\n",
    "    iterative_steps=iterative_steps, # the number of iterations to achieve target ratio\n",
    "    pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}\n",
    "    ignored_layers=ignored_layers,\n",
    ")\n",
    "\n",
    "base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)\n",
    "for i in range(iterative_steps):\n",
    "    # 3. the pruner.step will remove some channels from the model with least importance\n",
    "    pruner.step()\n",
    "    \n",
    "    # 4. Do whatever you like here, such as fintuning\n",
    "    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)\n",
    "    print(model)\n",
    "    print(model(example_inputs).shape)\n",
    "    print(\n",
    "        \"  Iter %d/%d, Params: %.2f M => %.2f M\"\n",
    "        % (i+1, iterative_steps, base_nparams / 1e6, nparams / 1e6)\n",
    "    )\n",
    "    print(\n",
    "        \"  Iter %d/%d, MACs: %.2f G => %.2f G\"\n",
    "        % (i+1, iterative_steps, base_macs / 1e9, macs / 1e9)\n",
    "    )\n",
    "    # finetune your model here\n",
    "    # finetune(model)\n",
    "    # ...\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
